from TransientImage import TransientImage

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from mpl_toolkits.axes_grid1 import make_axes_locatable

class TimeSlicePlot:
	"""Interactive matplotlib viewer for a transient image.

	Shows one slice ``image.data[:, :, bin]`` at a time (the third axis of
	``image.data`` is indexed by the time slider) or, after pressing
	"Show all", the average over that axis. A second slider sets the upper
	colour limit on a natural-log scale.
	"""
	image = None #type: TransientImage
	name = None #type: str
	figure = None
	img = None
	
	# displaying properties
	vMin = None #type: float
	vMax = None #type: float
	log = None #type: bool
	bin = None #type: int
	minDL = None #type: float
	maxDL = None #type: float

	def __init__(self, image : "TransientImage", name : str="") -> None:
		"""Build the figure, the image artist and the widgets.

		image -- object providing ``data`` (3-D array), ``tMin`` and ``tMax``
		name  -- suffix appended to the figure title

		Raises ValueError if ``image.data`` holds no strictly positive value.
		"""
		self.image = image
		self.name = name
		
		self.vMin = np.min(self.image.data)
		self.vMax = np.max(self.image.data)
		self.bin = 0
		self.log = False
		
		# find min/max values for slider (log scale, so only positive samples count)
		sliceRange = self._LogRange(self.image.data)
		if sliceRange is None:
			raise ValueError("No signal in transient image!")
		self.minDL, self.maxDL = sliceRange
		# the averaged image might have smaller values than any single slice;
		# it may also hold no positive value at all, in which case it is skipped
		meanRange = self._LogRange(self._MeanImage())
		if meanRange is not None:
			self.minDL = min(self.minDL, meanRange[0])
			self.maxDL = max(self.maxDL, meanRange[1])
		
		# mpl stuff:
		self.figure = plt.figure("Time slices "+self.name)
		
		# the image
		imgAxes = self.figure.add_subplot(111)
		self.img = imgAxes.imshow(self.image.data[:,:,0])
		divider = make_axes_locatable(imgAxes)
		colorbarAxes = divider.append_axes("right", size="5%", pad=0.05)
		self.figure.colorbar(self.img, cax=colorbarAxes)
		self.Redraw()

		# time slider (open upper bound: tMax itself lies outside the last bin)
		sliderAxes = divider.append_axes("bottom", size="5%", pad=0.3)
		self.sliderTime = Slider(sliderAxes, "$t$", self.image.tMin, self.image.tMax, valinit=self.image.tMin, closedmax=False)
		self.sliderTime.on_changed(self.TimeSliderUpdate)
		
		# percentile slider
		# The upper bound is left closed here: with closedmax=False matplotlib
		# rejects valinit == valmax and starts the slider at valmin instead.
		sliderAxes = divider.append_axes("bottom", size="5%", pad=0.1)
		self.sliderPercentile = Slider(sliderAxes, "$p$", self.minDL, self.maxDL, valinit=self.maxDL)
		self.sliderPercentile.on_changed(self.PercentileSliderUpdate)
	
		# ShowAll button
		buttonAxes = self.figure.add_axes([0.8, 0.025, 0.1, 0.04])
		self.showAll = Button(buttonAxes, "Show all")
		self.showAll.on_clicked(self.ShowAllButton)

	@staticmethod
	def _LogRange(values):
		"""Return (log(min), log(max)+1) over the positive entries of values, or None if there are none."""
		positive = values[values > 0]
		if 0 == positive.size:
			return None
		return np.log(np.min(positive)), np.log(np.max(positive))+1

	def _MeanImage(self):
		"""Return image.data averaged over its third axis."""
		return np.sum(self.image.data, axis=2) / float(self.image.data.shape[2])

	def Redraw(self) -> None:
		"""Push the current slice (or the average if bin < 0) and colour limits to the image artist."""
		if self.bin >= 0:
			data = self.image.data[:, :, self.bin]
		else:
			data = self._MeanImage()
		
		if self.log:
			# Non-positive pixels become -inf/nan; that is expected here, so
			# keep NumPy from emitting a RuntimeWarning on every redraw.
			with np.errstate(divide="ignore", invalid="ignore"):
				data = np.log(data)
		
		self.img.set_data(data)
		self.img.set_clim(vmin=self.vMin, vmax=self.vMax)
		self.figure.canvas.draw_idle()

	# Event Handler:
	def ShowAllButton(self, event) -> None:
		"""Button callback: switch to the average over all time bins (bin = -1)."""
		self.bin = -1
		self.Redraw()
	
	def TimeSliderUpdate(self, val : float) -> None:
		"""Slider callback: map time val in [tMin, tMax] to a bin index and redraw.

		The index is clamped to the valid range so that a value at (or
		rounded up to) tMax cannot index past the last bin, and a value below
		tMin cannot turn into the negative "show all" marker.
		"""
		binCount = self.image.data.shape[2]
		span = self.image.tMax-self.image.tMin
		if span != 0:
			index = int(((val-self.image.tMin)*binCount)/span)
		else:
			# degenerate time range: everything belongs to the first bin
			index = 0
		self.bin = min(max(index, 0), binCount-1)
		self.Redraw()
		
	def PercentileSliderUpdate(self, i : float) -> None:
		"""Slider callback: set the colour limits from log-scale value i and redraw.

		In log mode the limits are (minDL, i); otherwise they are converted
		back to linear intensities with exp().
		"""
		if self.log:
			self.vMin = self.minDL
			self.vMax = i
		else:
			self.vMin = np.exp(self.minDL)
			self.vMax = np.exp(i)
			
		self.Redraw()
